{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing packages:\n",
      "\t.package(path: \"/home/sgugger/git/course-v3/nbs/swift/FastaiNotebook_06_cuda\")\n",
      "\t\tFastaiNotebook_06_cuda\n",
      "With SwiftPM flags: []\n",
      "Working in: /tmp/tmp6kdf0qjm/swift-install\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "Updating https://github.com/mxcl/Path.swift\n",
      "Updating https://github.com/saeta/Just\n",
      "Updating https://github.com/latenitesoft/NotebookExport\n",
      "Completed resolution in 2.03s\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[1/11] Compiling FastaiNotebook_06_cuda 01_matmul.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[2/11] Compiling FastaiNotebook_06_cuda 03_minibatch_training.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[3/11] Compiling FastaiNotebook_06_cuda 02_fully_connected.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[4/11] Compiling FastaiNotebook_06_cuda 05b_early_stopping.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[5/11] Compiling FastaiNotebook_06_cuda 06_cuda.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[6/11] Compiling FastaiNotebook_06_cuda 05_anneal.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[7/11] Compiling FastaiNotebook_06_cuda 02a_why_sqrt5.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[8/11] Compiling FastaiNotebook_06_cuda 00_load_data.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[9/11] Compiling FastaiNotebook_06_cuda 04_callbacks.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[10/11] Compiling FastaiNotebook_06_cuda 01a_fastai_layers.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[11/12] Merging module FastaiNotebook_06_cuda\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[12/13] Compiling jupyterInstalledPackages jupyterInstalledPackages.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[13/14] Merging module jupyterInstalledPackages\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift-autolink-extract: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift-autolink-extract)\n",
      "[14/14] Linking libjupyterInstalledPackages.so\n",
      "Initializing Swift...\n",
      "Installation complete!\n"
     ]
    }
   ],
   "source": [
    "%install-location $cwd/swift-install\n",
    "%install '.package(path: \"$cwd/FastaiNotebook_06_cuda\")' FastaiNotebook_06_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "import Path\n",
    "import TensorFlow\n",
    "import Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import FastaiNotebook_06_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('inline', 'module://ipykernel.pylab.backend_inline')\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%include \"EnableIPythonDisplay.swift\"\n",
    "IPythonDisplay.shell.enable_matplotlib(\"inline\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let data = mnistDataBunch(flat: false, bs: 512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func optFunc(_ model: CnnModel) -> SGD<CnnModel> { return SGD(for: model, learningRate: 0.4) }\n",
    "func modelInit() -> CnnModel { return CnnModel(channelIn: 1, nOut: 10, filters: [8, 16, 32, 32]) }\n",
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n",
    "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n",
    "learner.addDelegates([learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std),\n",
    "                      learner.makeAddChannel()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: [0.48433208, 0.8463]                                                  \n",
      "average: 4766.350142 ms,   min: 4766.350142 ms,   max: 4766.350142 ms      \n"
     ]
    }
   ],
   "source": [
    "time { try! learner.fit(1) }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batchnorm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by building our own `BatchNorm` layer from scratch. Eventually we intend for this code to do the trick:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct AlmostBatchNorm<Scalar: TensorFlowFloatingPoint>: Differentiable {\n",
    "    // Configuration hyperparameters\n",
    "    @noDerivative let momentum, epsilon: Scalar\n",
    "    // Running statistics\n",
    "    @noDerivative var runningMean, runningVariance: Tensor<Scalar>\n",
    "    // Trainable parameters\n",
    "    var scale, offset: Tensor<Scalar>\n",
    "    \n",
    "    init(featureCount: Int, momentum: Scalar = 0.9, epsilon: Scalar = 1e-5) {\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "        self.runningMean = Tensor(0)\n",
    "        self.runningVariance = Tensor(1)\n",
    "    }\n",
    "\n",
    "    mutating func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean: Tensor<Scalar>\n",
    "        let variance: Tensor<Scalar>\n",
    "        switch Context.local.learningPhase {\n",
    "        case .training:\n",
    "            mean = input.mean(alongAxes: [0, 1, 2])\n",
    "            variance = input.variance(alongAxes: [0, 1, 2])\n",
    "            runningMean += (mean - runningMean) * (1 - momentum)\n",
    "            runningVariance += (variance - runningVariance) * (1 - momentum)\n",
    "        case .inference:\n",
    "            mean = runningMean\n",
    "            variance = runningVariance\n",
    "        }\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But there are some automatic differentiation limitations (control flow support) and `Layer` protocol constraints (mutating `call`) that make this impossible for now (note the lack of `@differentiable` or a `Layer` conformance), so we'll need a few workarounds. A `Reference` will let us update running statistics without declaring the `applied` method `mutating`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public class Reference<T> {\n",
    "    public var value: T\n",
    "    public init(_ value: T) { self.value = value }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following snippet will let us differentiate a layer's `call` method if it's composed of training and inference implementations that are each differentiable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public protocol LearningPhaseDependent: FALayer {\n",
    "    associatedtype Input\n",
    "    associatedtype Output\n",
    "    \n",
    "    @differentiable func forwardTraining(_ input: Input) -> Output\n",
    "    @differentiable func forwardInference(_ input: Input) -> Output\n",
    "}\n",
    "\n",
    "extension LearningPhaseDependent {\n",
    "    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer\n",
    "    // protocol requirement, even though there is a `@differentiating(forward)` method below.\n",
    "    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,\n",
    "    // some require it. Investigate.\n",
    "    @differentiable\n",
    "    public func forward(_ input: Input) -> Output {\n",
    "        switch Context.local.learningPhase {\n",
    "        case .training:  return forwardTraining(input)\n",
    "        case .inference: return forwardInference(input)\n",
    "        }\n",
    "    }\n",
    "\n",
    "    @differentiating(forward)\n",
    "    func gradForward(_ input: Input) ->\n",
    "        (value: Output, pullback: (Self.Output.TangentVector) ->\n",
    "            (Self.TangentVector, Self.Input.TangentVector)) {\n",
    "        switch Context.local.learningPhase {\n",
    "        case .training:\n",
    "            return valueWithPullback(at: input) { $0.forwardTraining ($1) }\n",
    "        case .inference:\n",
    "            return valueWithPullback(at: input) { $0.forwardInference($1) }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can implement a BatchNorm that we can use in our models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{\n",
    "    associatedtype Scalar\n",
    "    init(featureCount: Int, epsilon: Scalar)\n",
    "}\n",
    "\n",
    "public struct FABatchNorm<Scalar: TensorFlowFloatingPoint>: LearningPhaseDependent, Norm {\n",
    "    // TF-603 workaround.\n",
    "    public typealias Input = Tensor<Scalar>\n",
    "    public typealias Output = Tensor<Scalar>\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    \n",
    "    // Configuration hyperparameters\n",
    "    @noDerivative var momentum, epsilon: Scalar\n",
    "    // Running statistics\n",
    "    @noDerivative let runningMean, runningVariance: Reference<Tensor<Scalar>>\n",
    "    // Trainable parameters\n",
    "    public var scale, offset: Tensor<Scalar>\n",
    "    \n",
    "    public init(featureCount: Int, momentum: Scalar, epsilon: Scalar = 1e-5) {\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "        self.runningMean = Reference(Tensor(0))\n",
    "        self.runningVariance = Reference(Tensor(1))\n",
    "    }\n",
    "    \n",
    "    public init(featureCount: Int, epsilon: Scalar = 1e-5) {\n",
    "        self.init(featureCount: featureCount, momentum: 0.9, epsilon: epsilon)\n",
    "    }\n",
    "\n",
    "    @differentiable\n",
    "    public func forwardTraining(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean = input.mean(alongAxes: [0, 1, 2])\n",
    "        let variance = input.variance(alongAxes: [0, 1, 2])\n",
    "        runningMean.value += (mean - runningMean.value) * (1 - momentum)\n",
    "        runningVariance.value += (variance - runningVariance.value) * (1 - momentum)\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "    \n",
    "    @differentiable\n",
    "    public func forwardInference(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean = runningMean.value\n",
    "        let variance = runningVariance.value\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TensorFlow provides a highly optimized batch norm implementation, let us redefine our batch norm to invoke it directly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "struct BatchNormResult<Scalar : TensorFlowFloatingPoint> : Differentiable{\n",
    "    var y, batchMean, batchVariance, reserveSpace1, reserveSpace2: Tensor<Scalar>\n",
    "}\n",
    "\n",
    "public struct TFBatchNorm<Scalar: TensorFlowFloatingPoint>: LearningPhaseDependent, Norm {\n",
    "    // Configuration hyperparameters\n",
    "    @noDerivative var momentum, epsilon: Scalar\n",
    "    // Running statistics\n",
    "    @noDerivative let runningMean, runningVariance: Reference<Tensor<Scalar>>\n",
    "    // Trainable parameters\n",
    "    public var scale, offset: Tensor<Scalar>\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    \n",
    "    public init(featureCount: Int, momentum: Scalar, epsilon: Scalar = 1e-5) {\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "        self.runningMean = Reference(Tensor(0))\n",
    "        self.runningVariance = Reference(Tensor(1))\n",
    "    }\n",
    "    \n",
    "    public init(featureCount: Int, epsilon: Scalar = 1e-5) {\n",
    "        self.init(featureCount: featureCount, momentum: 0.9, epsilon: epsilon)\n",
    "    }\n",
    "\n",
    "    @differentiable\n",
    "    public func forwardTraining(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let res = TFBatchNorm<Scalar>.fusedBatchNorm(\n",
    "            input, scale: scale, offset: offset, epsilon: epsilon)\n",
    "        let (output, mean, variance) = (res.y, res.batchMean, res.batchVariance)\n",
    "        runningMean.value += (mean - runningMean.value) * (1 - momentum)\n",
    "        runningVariance.value += (variance - runningVariance.value) * (1 - momentum)\n",
    "        return output\n",
    "     }\n",
    "    \n",
    "    @differentiable\n",
    "    public func forwardInference(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean = runningMean.value\n",
    "        let variance = runningVariance.value\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "    \n",
    "    @differentiable(wrt: (x, scale, offset), vjp: _vjpFusedBatchNorm)\n",
    "    static func fusedBatchNorm(\n",
    "        _ x : Tensor<Scalar>, scale: Tensor<Scalar>, offset: Tensor<Scalar>, epsilon: Scalar\n",
    "    ) -> BatchNormResult<Scalar> {\n",
    "        let ret = Raw.fusedBatchNormV2(\n",
    "            x, scale: scale, offset: offset, \n",
    "            mean: Tensor<Scalar>([] as [Scalar]), variance: Tensor<Scalar>([] as [Scalar]), \n",
    "            epsilon: Double(epsilon))\n",
    "        return BatchNormResult(\n",
    "            y: ret.y, batchMean: ret.batchMean, batchVariance: ret.batchVariance,\n",
    "            reserveSpace1: ret.reserveSpace1, reserveSpace2: ret.reserveSpace2\n",
    "        )\n",
    "    }\n",
    "\n",
    "    static func _vjpFusedBatchNorm(\n",
    "        _ x : Tensor<Scalar>, scale: Tensor<Scalar>, offset: Tensor<Scalar>, epsilon: Scalar\n",
    "    ) -> (BatchNormResult<Scalar>, \n",
    "          (BatchNormResult<Scalar>.TangentVector) -> (Tensor<Scalar>.TangentVector, \n",
    "                                                        Tensor<Scalar>.TangentVector, \n",
    "                                                        Tensor<Scalar>.TangentVector)) {\n",
    "      let bnresult = fusedBatchNorm(x, scale: scale, offset: offset, epsilon: epsilon)\n",
    "  \n",
    "        return (\n",
    "            bnresult, \n",
    "            {v in \n",
    "                let res = Raw.fusedBatchNormGradV2(\n",
    "                    yBackprop: v.y, x, scale: Tensor<Float>(scale), \n",
    "                    reserveSpace1: bnresult.reserveSpace1, \n",
    "                    reserveSpace2: bnresult.reserveSpace2, \n",
    "                    epsilon: Double(epsilon))\n",
    "                return (res.xBackprop, res.scaleBackprop, res.offsetBackprop)\n",
    "            })\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public struct ConvBN<Scalar: TensorFlowFloatingPoint>: FALayer {\n",
    "    // TF-603 workaround.\n",
    "    public typealias Input = Tensor<Scalar>\n",
    "    public typealias Output = Tensor<Scalar>\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    public var conv: FANoBiasConv2D<Scalar>\n",
    "    public var norm: FABatchNorm<Scalar>\n",
    "    \n",
    "    public init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 1){\n",
    "        // TODO (when control flow AD works): use Conv2D without bias\n",
    "        self.conv = FANoBiasConv2D(cIn, cOut, ks: ks, stride: stride, activation: relu)\n",
    "        self.norm = FABatchNorm(featureCount: cOut, epsilon: 1e-5)\n",
    "    }\n",
    "\n",
    "    @differentiable\n",
    "    public func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        return norm.forward(conv.forward(input))\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Would be great if this generic could work\n",
    "// struct ConvNorm<NormType: Norm, Scalar: TensorFlowFloatingPoint>: Layer\n",
    "//     where NormType.Scalar == Scalar {\n",
    "//     var conv: Conv2D<Scalar>\n",
    "//     var norm: NormType\n",
    "//     init(\n",
    "//         filterShape: (Int, Int, Int, Int),\n",
    "//         strides: (Int, Int) = (1, 1),\n",
    "//         padding: Padding = .valid,\n",
    "//         activation: @escaping Conv2D<Scalar>.Activation = identity\n",
    "//     ) {\n",
    "//         // TODO (when control flow AD works): use Conv2D without bias\n",
    "//         self.conv = Conv2D(\n",
    "//             filterShape: filterShape,\n",
    "//             strides: strides,\n",
    "//             padding: padding,\n",
    "//             activation: activation)\n",
    "//         self.norm = NormType.init(featureCount: filterShape.3, epsilon: 1e-5)\n",
    "//     }\n",
    "\n",
    "//     @differentiable\n",
    "//     func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "//         return norm.applied(to: conv.applied(to: input))\n",
    "//     }\n",
    "// }\n",
    "//typealias ConvBN = ConvNorm<BatchNorm<Float>, Float>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public struct CnnModelBN: Layer {\n",
    "    public var convs: [ConvBN<Float>]\n",
    "    public var pool = FAGlobalAvgPool2D<Float>()\n",
    "    public var linear: FADense<Float>\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    \n",
    "    public init(channelIn: Int, nOut: Int, filters: [Int]){\n",
    "        let allFilters = [channelIn] + filters\n",
    "        convs = Array(0..<filters.count).map { i in\n",
    "            return ConvBN(allFilters[i], allFilters[i+1], ks: 3, stride: 2)\n",
    "        }\n",
    "        linear = FADense<Float>(filters.last!, nOut)\n",
    "    }\n",
    "    \n",
    "    @differentiable\n",
    "    public func callAsFunction(_ input: TF) -> TF {\n",
    "        // TODO: Work around https://bugs.swift.org/browse/TF-606\n",
    "        return linear.forward(pool.forward(convs(input)))\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func optFunc(_ model: CnnModelBN) -> SGD<CnnModelBN> { return SGD(for: model, learningRate: 0.4) }\n",
    "func modelInit() -> CnnModelBN { return CnnModelBN(channelIn: 1, nOut: 10, filters: [8, 16, 32, 32]) }\n",
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)\n",
    "let recorder = learner.makeDefaultDelegates(metrics: [accuracy])\n",
    "learner.addDelegates([learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std),\n",
    "                      learner.makeAddChannel()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: [nan(0x1fffff), 0.098]                                                \n",
      "average: 2481.692572 ms,   min: 2481.692572 ms,   max: 2481.692572 ms      \n"
     ]
    }
   ],
   "source": [
    "time { try! learner.fit(1) }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More norms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Layer norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From [the paper](https://arxiv.org/abs/1607.06450): \"*batch normalization cannot be applied to online learning tasks or to extremely large distributed models where the minibatches have to be small*\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "General equation for a norm layer with learnable affine:\n",
    "\n",
    "$$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta$$\n",
    "\n",
    "The difference with BatchNorm is\n",
    "1. we don't keep a moving average\n",
    "2. we don't average over the batches dimension but over the hidden dimension, so it's independent of the batch size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct LayerNorm2D<Scalar: TensorFlowFloatingPoint>: Norm {\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    // Configuration hyperparameters\n",
    "    @noDerivative let epsilon: Scalar\n",
    "    // Trainable parameters\n",
    "    var scale: Tensor<Scalar>\n",
    "    var offset: Tensor<Scalar>\n",
    "    \n",
    "    init(featureCount: Int, epsilon: Scalar = 1e-5) {\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "    }\n",
    "    \n",
    "    @differentiable\n",
    "    func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean = input.mean(alongAxes: [1, 2, 3])\n",
    "        let variance = input.variance(alongAxes: [1, 2, 3])\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "}\n",
    "\n",
    "struct ConvLN<Scalar: TensorFlowFloatingPoint>: FALayer {\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    var conv: FANoBiasConv2D<Scalar>\n",
    "    var norm: LayerNorm2D<Scalar>\n",
    "    \n",
    "    init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 2){\n",
    "        // TODO (when control flow AD works): use Conv2D without bias\n",
    "        self.conv = FANoBiasConv2D(cIn, cOut, ks: ks, stride: stride, activation: relu)\n",
    "        self.norm = LayerNorm2D(featureCount: cOut, epsilon: 1e-5)\n",
    "    }\n",
    "\n",
    "    @differentiable\n",
    "    func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        return norm.callAsFunction(conv.forward(input))\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// CNN made of a stack of `ConvLN` blocks, a global average pool and a final dense layer.\n",
    "public struct CnnModelLN: Layer {\n",
    "    public var convs: [ConvLN<Float>]\n",
    "    public var pool = FAGlobalAvgPool2D<Float>()\n",
    "    public var linear: FADense<Float>\n",
    "    \n",
    "    /// - Parameters:\n",
    "    ///   - channelIn: input channel count of the first block.\n",
    "    ///   - nOut: output size of the final dense layer.\n",
    "    ///   - filters: output channel count of each block, in order; must not be empty\n",
    "    ///     (the last entry is force-unwrapped).\n",
    "    public init(channelIn: Int, nOut: Int, filters: [Int]){\n",
    "        // Pair each block's input channel count with its output channel count.\n",
    "        let inputChannels = [channelIn] + filters\n",
    "        convs = zip(inputChannels, filters).map { pair in\n",
    "            ConvLN<Float>(pair.0, pair.1, ks: 3, stride: 2)\n",
    "        }\n",
    "        linear = FADense<Float>(filters.last!, nOut)\n",
    "    }\n",
    "    \n",
    "    /// Applies the conv blocks, the pooling layer and the dense layer in sequence.\n",
    "    @differentiable\n",
    "    public func callAsFunction(_ input: TF) -> TF {\n",
    "        // TODO: Work around https://bugs.swift.org/browse/TF-606\n",
    "        let features = convs(input)\n",
    "        let pooled = pool.forward(features)\n",
    "        return linear.forward(pooled)\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Instance normalization: each sample and each channel is normalized with its own\n",
    "/// spatial statistics, then scaled and shifted by learned per-channel parameters.\n",
    "///\n",
    "/// Assumes rank-4 channels-last input (batch, height, width, channels): `scale` and\n",
    "/// `offset` have shape `[featureCount]` and therefore broadcast along the last axis.\n",
    "struct InstanceNorm<Scalar: TensorFlowFloatingPoint>: Norm {\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    // Configuration hyperparameters\n",
    "    /// Added to the variance before taking the reciprocal square root.\n",
    "    @noDerivative let epsilon: Scalar\n",
    "    // Trainable parameters\n",
    "    var scale: Tensor<Scalar>\n",
    "    var offset: Tensor<Scalar>\n",
    "    \n",
    "    /// Creates the layer with `scale` set to ones and `offset` set to zeros,\n",
    "    /// each holding `featureCount` elements.\n",
    "    init(featureCount: Int, epsilon: Scalar = 1e-5) {\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "    }\n",
    "    \n",
    "    /// Normalizes `input` over its spatial axes and applies `scale` and `offset`.\n",
    "    @differentiable\n",
    "    func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        // With channels on the last axis, the spatial axes are 1 and 2. Reducing over\n",
    "        // [2, 3] (the NCHW spatial axes) would mix width with channels instead.\n",
    "        let mean = input.mean(alongAxes: [1, 2])\n",
    "        let variance = input.variance(alongAxes: [1, 2])\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "}\n",
    "\n",
    "/// An `FANoBiasConv2D` (created with `relu` activation) followed by an `InstanceNorm`.\n",
    "struct ConvIN<Scalar: TensorFlowFloatingPoint>: FALayer {\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    var conv: FANoBiasConv2D<Scalar>\n",
    "    var norm: InstanceNorm<Scalar>\n",
    "    \n",
    "    /// Builds the convolution from `cIn` to `cOut` channels and a norm over `cOut` features.\n",
    "    init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 2){\n",
    "        // TODO (when control flow AD works): use Conv2D without bias\n",
    "        norm = InstanceNorm(featureCount: cOut, epsilon: 1e-5)\n",
    "        conv = FANoBiasConv2D(cIn, cOut, ks: ks, stride: stride, activation: relu)\n",
    "    }\n",
    "\n",
    "    /// Runs the convolution, then normalizes its output.\n",
    "    @differentiable\n",
    "    func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let convolved = conv.forward(input)\n",
    "        return norm.callAsFunction(convolved)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lost in all those norms? The authors of the [group norm paper](https://arxiv.org/pdf/1803.08494.pdf) have you covered:\n",
    "\n",
    "![Various norms](../dev_course/dl2/images/norms.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: GroupNorm is not implemented yet; it is skipped for now."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running Batch Norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Batch norm that normalizes with exponentially weighted running sums rather than\n",
    "/// the statistics of the current batch alone.\n",
    "///\n",
    "/// The running values are held in `Reference` boxes, which lets the non-mutating\n",
    "/// forward pass update them. Channels are read from axis 3 of the input, so a\n",
    "/// rank-4 channels-last layout is assumed.\n",
    "struct RunningBatchNorm<Scalar: TensorFlowFloatingPoint>: LearningPhaseDependent, Norm {\n",
    "    @noDerivative public var delegates: [(Self.Output) -> ()] = []\n",
    "    // Configuration hyperparameters\n",
    "    /// Base weight given to the previous running values (before batch-size scaling).\n",
    "    @noDerivative let momentum: Scalar\n",
    "    /// Added to the variance before taking the reciprocal square root.\n",
    "    @noDerivative let epsilon: Scalar\n",
    "    // Running statistics\n",
    "    @noDerivative let runningSum: Reference<Tensor<Scalar>>\n",
    "    @noDerivative let runningSumOfSquares: Reference<Tensor<Scalar>>\n",
    "    @noDerivative let runningCount: Reference<Scalar>\n",
    "    /// Total number of samples passed through `forwardTraining`.\n",
    "    @noDerivative let samplesSeen: Reference<Int>\n",
    "    // Trainable parameters\n",
    "    var scale: Tensor<Scalar>\n",
    "    var offset: Tensor<Scalar>\n",
    "    \n",
    "    /// Creates the layer with `scale` set to ones, `offset` set to zeros and all\n",
    "    /// running statistics set to zero.\n",
    "    init(featureCount: Int, momentum: Scalar, epsilon: Scalar = 1e-5) {\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "        self.scale = Tensor(ones: [featureCount])\n",
    "        self.offset = Tensor(zeros: [featureCount])\n",
    "        self.runningSum = Reference(Tensor(0))\n",
    "        self.runningSumOfSquares = Reference(Tensor(0))\n",
    "        self.runningCount = Reference(Scalar(0))\n",
    "        self.samplesSeen = Reference(0)\n",
    "    }\n",
    "    \n",
    "    /// Convenience initializer that uses a momentum of 0.9.\n",
    "    init(featureCount: Int, epsilon: Scalar = 1e-5) {\n",
    "        self.init(featureCount: featureCount, momentum: 0.9, epsilon: epsilon)\n",
    "    }\n",
    "\n",
    "    /// Folds the statistics of `input` into the running sums, then normalizes `input`\n",
    "    /// with the updated running mean and variance.\n",
    "    @differentiable\n",
    "    func forwardTraining(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let (batch, channels) = (input.shape[0], Scalar(input.shape[3]))\n",
    "        let sum = input.sum(alongAxes: [0, 1, 2])\n",
    "        let sumOfSquares = (input * input).sum(alongAxes: [0, 1, 2])\n",
    "        // TODO: Work around https://bugs.swift.org/browse/TF-607\n",
    "        // Number of elements that contribute to each channel's sum.\n",
    "        let count = withoutDerivative(at: Scalar(input.scalarCount)) { tmp in tmp } / channels\n",
    "        // Larger batches get a smaller weight on the old statistics. The radicand is\n",
    "        // clamped to at least 1 so that a batch of one does not divide by sqrt(0),\n",
    "        // which would make the running statistics non-finite from then on.\n",
    "        let mom = momentum / sqrt(Scalar(max(batch - 1, 1)))\n",
    "        let runningSum = mom * self.runningSum.value + (1 - mom) * sum\n",
    "        let runningSumOfSquares = mom * self.runningSumOfSquares.value + (\n",
    "            1 - mom) * sumOfSquares\n",
    "        let runningCount = mom * self.runningCount.value + (1 - mom) * count\n",
    "        \n",
    "        self.runningSum.value = runningSum\n",
    "        self.runningSumOfSquares.value = runningSumOfSquares\n",
    "        self.runningCount.value = runningCount\n",
    "        self.samplesSeen.value += batch\n",
    "        \n",
    "        // Sums and count are weighted identically, so their ratios are running averages.\n",
    "        let mean = runningSum / runningCount\n",
    "        let variance = runningSumOfSquares / runningCount - mean * mean\n",
    "        \n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "    \n",
    "    /// Normalizes `input` with the stored running statistics without updating them.\n",
    "    ///\n",
    "    /// NOTE: the running count starts at zero, so calling this before any training\n",
    "    /// step divides zero by zero.\n",
    "    @differentiable\n",
    "    func forwardInference(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean = runningSum.value / runningCount.value\n",
    "        let variance = runningSumOfSquares.value / runningCount.value - mean * mean\n",
    "        let normalizer = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * normalizer + offset\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: Add XLA compilation and test RunningBatchNorm (RBN)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "success\r\n"
     ]
    }
   ],
   "source": [
    "import NotebookExport\n",
    "// Export this notebook's code as a package whose name starts with the given prefix.\n",
    "let notebookPath = Path.cwd/\"07_batchnorm.ipynb\"\n",
    "let exporter = NotebookExport(notebookPath)\n",
    "let exportResult = exporter.export(usingPrefix: \"FastaiNotebook_\")\n",
    "print(exportResult)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Swift",
   "language": "swift",
   "name": "swift"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
